{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# [HW2] Problem: Outlier Removal via OMP (part 2)\n",
                "\n",
                "This problem is the second part of the OMP problem from HW1. We will look at why a validation set is needed here, and at how to use a validation set that itself contains outliers.\n",
                "\n",
                "## (a)\n",
                "First, we rerun a demo from the first part of the problem, which you have already seen.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import numpy as np\n",
                "from helpers import make_state_trace, identify_system, random_input, plot_eigenvalues\n",
                "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
                "from sklearn.linear_model import OrthogonalMatchingPursuit\n",
                "import scipy.io\n",
                "\n",
                "%matplotlib inline\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def greedy_OMP(X_aug, y, k):\n",
                "    \"\"\"\n",
                "    Performs orthogonal matching pursuit.\n",
                "\n",
                "    args:\n",
                "      X_aug: augmented data matrix, i.e., X_aug = [X I]\n",
                "      y: noisy observations with outliers\n",
                "      k: number of non-zero entries in the OMP solution; equal to\n",
                "        the number of iterations of the OMP algorithm\n",
                "\n",
                "    returns:\n",
                "      remaining_residual: mean squared residual of the fit on (X_aug, y)\n",
                "      coef: fitted coefficients, one per column of X_aug\n",
                "\n",
                "    Note: idx and plot_points are computed below but not returned.\n",
                "    \"\"\"\n",
                "\n",
                "    reg = OrthogonalMatchingPursuit(\n",
                "            n_nonzero_coefs=k, fit_intercept=False).fit(X_aug,y)\n",
                "    idx = [j for (i, j) in zip(reg.coef_, range(1, len(reg.coef_) + 1))\n",
                "           if abs(i) > 1e-7]\n",
                "\n",
                "    test_mat = np.zeros((2, X_aug.shape[1]))\n",
                "    test_mat[1, 0] = 10.0\n",
                "    plot_points = reg.predict(test_mat)\n",
                "\n",
                "    y = np.reshape(y, (y.shape[0], 1))\n",
                "    err_vec = (y - np.reshape(reg.predict(X_aug), (y.shape[0], 1))) ** 2\n",
                "    remaining_residual = np.mean(err_vec)\n",
                "    return remaining_residual, reg.coef_\n"
            ]
        },
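        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a reminder of what the augmented matrix `X_aug = [X I]` does, here is a minimal sketch on made-up toy data (not the homework data). It assumes the outlier index is already known, so it uses plain least squares via `np.linalg.lstsq` instead of OMP; in the actual algorithm, OMP chooses the identity columns greedily. The names `w_plain`, `cols`, and `coef` are illustrative only.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "\n",
                "# Toy data (made up for illustration): y = 2x, with one corrupted label.\n",
                "x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
                "y = 2.0 * x\n",
                "y[2] += 20.0  # outlier at index 2\n",
                "\n",
                "X = x.reshape(-1, 1)\n",
                "X_aug = np.concatenate((X, np.eye(X.shape[0])), axis=1)  # [X I], shape (5, 6)\n",
                "\n",
                "# Plain least squares on X alone is pulled toward the outlier.\n",
                "w_plain = np.linalg.lstsq(X, y, rcond=None)[0]\n",
                "\n",
                "# Keep the feature column plus the identity column of point 2.\n",
                "# ASSUMPTION: the outlier index is known here; OMP would pick it greedily.\n",
                "cols = [0, 1 + 2]\n",
                "coef = np.linalg.lstsq(X_aug[:, cols], y, rcond=None)[0]\n",
                "\n",
                "print('slope without identity column:', w_plain[0])\n",
                "print('slope with identity column:   ', coef[0])\n",
                "print('offset absorbed for point 2:  ', coef[1])\n",
                "```\n",
                "\n",
                "The coefficient on an identity column absorbs the residual of exactly one data point, which has the same effect on the fitted slope as removing that point. This is also why the coefficient vector returned by `greedy_OMP` has one entry per column of `X_aug`, not one per feature.\n"
            ]
        },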
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Load data\n",
                "mdict = scipy.io.loadmat(\"data.mat\")\n",
                "\n",
                "X, y = mdict[\"X_train\"], mdict[\"y_train\"].flatten()\n",
                "X_test, y_test = mdict[\"X_test\"], mdict[\"y_test\"].flatten()\n",
                "X_val, y_val = mdict[\"X_val\"], mdict[\"y_val\"].flatten()\n",
                "\n",
                "n = 50\n",
                "d = 1\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "errs = []\n",
                "weights = []\n",
                "X_aug = np.concatenate((X, np.eye(X.shape[0])), axis=1)\n",
                "num = n\n",
                "idxes_size = -1\n",
                "\n",
                "while(num > 1):\n",
                "    err, w = greedy_OMP(X_aug, y, n - num + 2)\n",
                "    errs.append(err)\n",
                "    weights.append(w)\n",
                "    num -= 1\n",
                "\n",
                "plt.figure(figsize=(10, 3))\n",
                "\n",
                "plt.subplot(121)\n",
                "plt.plot(np.arange(1, len(errs) + 1), errs)\n",
                "plt.title(\"Error vs number of points removed\")\n",
                "\n",
                "plt.subplot(122)\n",
                "plt.plot(np.arange(1, len(errs) + 1), errs)\n",
                "plt.yscale('log')\n",
                "plt.title(\"Error (log scale) vs number of points removed\")\n",
                "\n",
                "plt.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The left plot, which we saw in the previous homework, shows the training error of each OMP solution as a function of the number of outliers removed. The trend is hard to see on a linear scale, but the log-scale plot (right) shows that the training error keeps decreasing! Eventually, if we treat every data point as an outlier and remove all but one, we technically reach zero error. However, this is not what we want, because we would then be estimating the true weights from a single data point instead of from all the uncorrupted points.\n",
                "\n",
                "Next, we will show that this is indeed a bad thing to do by evaluating the OMP solutions we obtained earlier on a clean test set. This is the data distribution that we want to perform well on, not the corrupted one we have for training. **Fill in the space below to compute the errors on this test set using all the weights we obtained from OMP with different numbers of outliers removed.** The vector `test_errs` should contain the mean squared error on the test set, indexed by how many purported outliers have been removed.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "test_errs = []\n",
                "\n",
                "### start 1 ###\n",
                "\n",
                "### end 1 ###\n",
                "\n",
                "plt.figure(figsize=(5, 3))\n",
                "plt.plot(np.arange(1, len(test_errs) + 1), test_errs, color='orange')\n",
                "plt.title(\"Test error vs number of outliers removed\")\n",
                "plt.yscale('log')\n",
                "plt.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "**Describe in words the trend you see in the test error. Based on the plot alone, what is a good number of outliers for OMP to remove?**\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## (b)\n",
                "\n",
                "**Compute the validation errors for some of the weight vectors obtained from OMP (`weights` at indices 0, 10, 20, 30, and 40). For each weight vector, plot a histogram of the errors and compute their mean and median.** Note where the mean and median fall in each histogram.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "for idx in [0, 10, 20, 30, 40]:\n",
                "    ### start 2 ###\n",
                "\n",
                "    ### end 2 ###\n",
                "    mean = np.mean(val_err)\n",
                "    median = np.median(val_err)\n",
                "    plt.hist(val_err)\n",
                "    plt.show()\n",
                "    print('k=%d, mean: %.4f, median: %.4f.' % (idx + 2, mean, median))\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## (c)\n",
                "\n",
                "Does the idea of the median being a robust estimator suggest a way to do \"robust\" validation? **Plot the mean and median of the validation errors against the number of removed outliers. Now determine from the median plot where to stop. Why does the median work better than the mean in this case?**\n"
            ]
        },
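        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a small illustration of what \"robust\" means here, the sketch below compares the mean and the median of a handful of squared errors. The numbers are hypothetical and chosen only for illustration; they are not taken from the validation set in `data.mat`.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "\n",
                "# Hypothetical squared errors: four clean points and one corrupted point.\n",
                "sq_errs = np.array([0.10, 0.20, 0.15, 0.12, 50.0])\n",
                "\n",
                "print('mean:  ', np.mean(sq_errs))    # dominated by the single large error\n",
                "print('median:', np.median(sq_errs))  # set by the middle of the sorted errors\n",
                "\n",
                "# Making the corrupted error 100x larger moves the mean but not the median.\n",
                "sq_errs_worse = sq_errs.copy()\n",
                "sq_errs_worse[-1] *= 100.0\n",
                "print('mean:  ', np.mean(sq_errs_worse))\n",
                "print('median:', np.median(sq_errs_worse))\n",
                "```\n",
                "\n",
                "The median depends only on the ordering of the errors around the middle, so it stays put as long as fewer than half of the points are corrupted.\n"
            ]
        },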
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "mean_val_errs = []\n",
                "median_val_errs = []\n",
                "\n",
                "### start 3 ###\n",
                "\n",
                "### end 3 ###\n",
                "\n",
                "plt.figure(figsize=(5, 3))\n",
                "plt.plot(np.arange(1, len(mean_val_errs) + 1), mean_val_errs, label=\"mean\")\n",
                "plt.plot(np.arange(1, len(median_val_errs) + 1), median_val_errs, label=\"median\")\n",
                "plt.title(\"Validation error vs number of augmented bases\")\n",
                "plt.yscale('log')\n",
                "plt.legend()\n",
                "plt.tight_layout()\n",
                "plt.show()\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.3"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 4
}